今天我們來實作GAN,簡單複習一下,GAN的Component 有 Generator
以及 Discriminiator
。
而 Generator 任務就是產生圖片來騙過Discriminator , Discriminator的任務就是努力判斷 Generator 所產生圖片的品質。因此,"對抗" 的概念就是從此而來。在實作上,相較於之前的AutoEncoder Model,GAN大部分都是使用 Convolution layer (卷積層) ,而非像之前其他有許多的Dense layer (全連接層)。下圖就是簡易的 Generator 以及 Discriminator 的架構,GAN在training的時候非常不穩定,因此一些layer的設定或者Actvation function的選擇都要注意。
其實 Discriminator 蠻直觀的,其實他就是一個圖片分類器,用以判斷 Generator 產生圖片的好壞。因此,我們可以來看一下如何定義 Discriminator,可以直接先透過Call function來看前向傳播的架構。
就是透過 Conv -> BN -> Con -> BN ... -> Falltern -> Dense ,注意Activation fucntion的部分都是使用Leaky Relu (Leacky Relu跟Relu最大差別就是當值小於0的時候的差別,Relu只要小於0均為0,Leaky Relu則仍會有值)。
class Discriminator(keras.Model):
def __init__(self):
super(Discriminator,self).__init__()
self.conv_1 = layers.Conv2D(64,5,3,'valid')
self.conv_2 = layers.Conv2D(128,5,3,'valid')
self.bn_1 = layers.BatchNormalization()
self.conv_3 = layers.Conv2D(256,5,3,'valid')
self.bn_2 = layers.BatchNormalization()
self.flatten = layers.Flatten()
self.fc_layer = layers.Dense(1)
def call(self, inputs, training=None):
x = tf.nn.leaky_relu(self.conv_1(inputs))
x = tf.nn.leaky_relu(self.bn_1(self.conv_2(x),training=training))
x = tf.nn.leaky_relu(self.bn_2(self.conv_3(x),training=training))
x = self.flatten(x)
x = self.fc_layer(x)
return x
Generator的部分,主要為一個圖片產生器,透過一個低維度的matrix,還原成一張正常的圖片。在Generator中
,會使用 tf.layers.Conv2DTranspose
(反卷積) ,簡單來說就是把特徵還原成圖片的概念 (如下圖)
接下來,可以直接先透過Call function來看前向傳播的架構。
Input -> Dense -> Conv Transpose -> BN -> .. -> Tanh
class Generator(keras.Model):
def __init__(self):
super(Generator,self).__init__()
#encoder
self.fc_layer_1 = layers.Dense(3*3*512)
self.conv_1 = layers.Conv2DTranspose(256,3,3,'valid')
self.bn_1 = layers.BatchNormalization()
self.conv_2 = layers.Conv2DTranspose(128,5,2,'valid')
self.bn_2 = layers.BatchNormalization()
self.conv_3 = layers.Conv2DTranspose(3,4,3,'valid')
def call(self, inputs, training=None):
x = self.fc_layer_1(inputs)
x = tf.reshape(x,[-1,3,3,512])
x = tf.nn.leaky_relu(x)
x = self.bn_1(self.conv_1(x),training=training)
x = self.bn_2(self.conv_2(x),training=training)
x = self.conv_3(x)
x = tf.tanh(x)
return x
我們就完成 Generator 和 Discriminator 的建置。接下來就可以做簡單的測試。
x = tf.random.normal([1,64,64,3])
z = tf.random.normal([1,100])
prob = g(x)
print(prob)
out = d(x)
print(out.shape)
今天完成簡易的GAN 模型與建立,明天會跑真實的資料! 感謝大家漫長閱讀。祝大家連假愉快